from typing import List
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel, AutoConfig, AutoModelForCausalLM
from .segment_anything_2.sam2.build_sam import build_sam2, build_sam2_video_predictor
from .unilm.beit3.modeling_utils import BEiT3Wrapper, _get_base_config, _get_large_config
from .configuration_evf import EvfConfig
from .segment_anything_2.sam2.utils.misc import load_video_frames
from collections import OrderedDict



def dice_loss(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    num_masks: float,
    scale=1000,  # 100000.0,
    eps=1e-6,
):
    """
    Soft DICE loss on mask logits, similar to generalized IOU for masks.

    Args:
        inputs: Float tensor of raw logits; a sigmoid is applied here. Dims 1
                and 2 are flattened together, so at least three dims are needed.
        targets: Tensor with the same shape as inputs holding the binary label
                 of every element (0 for the negative class, 1 for the
                 positive class).
        num_masks: Normaliser; the summed loss is divided by
                   ``num_masks + 1e-8``.
        scale: Divisor applied to the prediction and target sums before they
               are combined into the DICE ratio.
        eps: Smoothing term added to both numerator and denominator of the
             DICE ratio.
    Returns:
        0-dim loss tensor.
    """
    probs = inputs.sigmoid().flatten(1, 2)
    labels = targets.flatten(1, 2)

    # Down-scale before reducing over the last dim.
    scaled_probs = probs / scale
    overlap = (scaled_probs * labels).sum(-1)
    total = scaled_probs.sum(-1) + (labels / scale).sum(-1)

    dice_ratio = (2 * overlap + eps) / (total + eps)
    per_mask_loss = 1 - dice_ratio
    return per_mask_loss.sum() / (num_masks + 1e-8)


def sigmoid_ce_loss(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    num_masks: float,
):
    """
    Binary cross-entropy on logits, averaged per mask and normalised by
    ``num_masks``.

    Args:
        inputs: Float tensor of raw logits (the sigmoid is folded into the
                loss). Dims 1 and 2 are flattened together and averaged, so
                at least three dims are needed.
        targets: Float tensor with the same shape as inputs holding the binary
                 label of every element (0 for the negative class, 1 for the
                 positive class).
        num_masks: Normaliser; the summed loss is divided by
                   ``num_masks + 1e-8``.
    Returns:
        0-dim loss tensor.
    """
    elementwise = F.binary_cross_entropy_with_logits(
        inputs, targets, reduction="none"
    )
    per_mask = elementwise.flatten(1, 2).mean(1)
    return per_mask.sum() / (num_masks + 1e-8)

class EvfSam2Model(PreTrainedModel):
    """
    Text-prompted video segmentation model.

    A BEiT-3 multimodal encoder (``mm_extractor``) turns an image plus a text
    prompt into a single prompt embedding, which is projected by
    ``text_hidden_fcs`` and handed to a SAM-2 video predictor
    (``visual_model``) that propagates the resulting mask through a video.

    Keyword arguments read by ``__init__``:
        vision_pretrained: Forwarded as the second positional argument of
            ``build_sam2_video_predictor`` (presumably a checkpoint path —
            confirm against that builder). Defaults to None.
        encoder_pretrained: Path of a checkpoint readable by ``torch.load``
            whose ``"model"`` entry is a BEiT-3 state dict. Defaults to None
            (no weights loaded).
        dice_loss_weight / bce_loss_weight: Loss weights; only referenced by
            the disabled training ``forward`` kept below. Default None.
        train_mask_decoder: When true, the SAM-2 mask decoder is unfrozen.
        train_prompt_encoder: When true, the prompt encoder's
            ``no_mask_embed`` is unfrozen.
    """
    config_class = EvfConfig

    def __init__(
        self,
        config,
        **kwargs
    ):
        """
        Store the training options from ``kwargs`` and build all sub-modules.

        Args:
            config: An ``EvfConfig``; ``sam_scale``, ``mm_extractor_scale``,
                ``hidden_size`` and ``out_dim`` are read from it.
            **kwargs: See the class docstring.
        """
        super().__init__(config)

        self.config = config
        self.vision_pretrained = kwargs.get("vision_pretrained", None)
        self.encoder_pretrained = kwargs.get("encoder_pretrained", None)
        self.dice_loss_weight = kwargs.get("dice_loss_weight", None)
        self.bce_loss_weight = kwargs.get("bce_loss_weight", None)
        self.train_mask_decoder = kwargs.get("train_mask_decoder", False)
        self.train_prompt_encoder = kwargs.get("train_prompt_encoder", False)
        self.initialize_evf_modules(config)
        # Only referenced by the disabled training forward below; presumably
        # the (H, W) sizes of the backbone feature maps — TODO confirm.
        self._bb_feat_sizes = [
            (256, 256),
            (128, 128),
            (64, 64),
        ]

    def initialize_evf_modules(self, config):
        """
        Build the SAM-2 predictor, the BEiT-3 encoder and the projection head,
        and set which parameters are trainable.

        Args:
            config: Model config (same object as ``self.config``).

        Raises:
            NotImplementedError: ``config.sam_scale`` is neither ``"large"``
                nor ``"tiny"``.
            AttributeError: ``config.mm_extractor_scale`` is neither
                ``"base"`` nor ``"large"``.
            AssertionError: ``config.hidden_size`` differs from the BEiT-3
                embedding dimension.
        """
        # SAM
        if config.sam_scale == "large":
            sam_config_file = "sam2_hiera_l.yaml"
        elif config.sam_scale == "tiny":
            sam_config_file = "sam2_hiera_t.yaml"
        else:
            raise NotImplementedError(
                f"unsupported sam_scale {config.sam_scale!r}; expected 'large' or 'tiny'."
            )
        self.visual_model = build_sam2_video_predictor(
            sam_config_file, self.vision_pretrained, device=None
        )

        # Freeze the whole SAM-2 model first, then selectively unfreeze.
        for param in self.visual_model.parameters():
            param.requires_grad = False
        if self.train_mask_decoder:
            self.visual_model.sam_mask_decoder.train()
            for param in self.visual_model.sam_mask_decoder.parameters():
                param.requires_grad = True
        if self.train_prompt_encoder:
            self.visual_model.sam_prompt_encoder.no_mask_embed.requires_grad_(True)

        # beit-3
        if config.mm_extractor_scale == "base":
            beit_config = _get_base_config()
        elif config.mm_extractor_scale == "large":
            beit_config = _get_large_config()
        else:
            raise AttributeError("model config should contain key 'mm_extractor_scale', with value 'base' or 'large'.")

        self.mm_extractor = BEiT3Wrapper(beit_config)
        if self.encoder_pretrained is not None:
            # map_location="cpu" lets a checkpoint saved on GPU load on any
            # host; load_state_dict then copies the values into the module's
            # own parameters, so the final result is unchanged.
            # NOTE: torch.load may unpickle arbitrary objects — only point
            # encoder_pretrained at trusted checkpoint files.
            beit_state_dict = torch.load(
                self.encoder_pretrained, map_location="cpu"
            )["model"]
            self.mm_extractor.load_state_dict(
                beit_state_dict,
                strict=False
            )

        # The multimodal encoder is always trainable.
        for param in self.mm_extractor.parameters():
            param.requires_grad = True

        # Projection layer
        in_dim = config.hidden_size
        assert in_dim==beit_config.encoder_embed_dim, \
            f"projection layer dim {in_dim} mismatch with mm_extractor dim {beit_config.encoder_embed_dim}"
        out_dim = config.out_dim
        text_fc = [
            nn.Linear(in_dim, in_dim),
            nn.ReLU(),
            nn.Linear(in_dim, out_dim)
        ]
        self.text_hidden_fcs = nn.ModuleList([nn.Sequential(*text_fc)])
        self.text_hidden_fcs.train()
        for param in self.text_hidden_fcs.parameters():
            param.requires_grad = True

    def postprocess_masks(self, masks: torch.Tensor, orig_hw) -> torch.Tensor:
        """
        Resize mask logits to a target resolution.

        Args:
            masks: Mask tensor; it is cast to float and passed to
                ``F.interpolate`` in bilinear mode.
            orig_hw: Target spatial size handed to ``F.interpolate``.

        Returns:
            The resized float mask tensor.
        """
        masks = masks.float()
        masks = F.interpolate(masks, orig_hw, mode="bilinear", align_corners=False)
        return masks

    # NOTE: training-time forward pass (computes the BCE + DICE mask losses).
    # It is currently disabled and kept only as a reference.
    #
    # def forward(
    #     self,
    #     images: torch.FloatTensor,
    #     images_evf: torch.FloatTensor,
    #     input_ids: torch.LongTensor,
    #     attention_masks: torch.LongTensor,
    #     offset: torch.LongTensor,
    #     masks_list: List[torch.FloatTensor],
    #     label_list: List[torch.Tensor],
    #     resize_list: List[tuple],
    #     inference: bool = False,
    #     **kwargs,
    # ):
    #     # image_embeddings = self.get_visual_embs(images)
    #     backbone_out = self.visual_model.forward_image(images)
    #     # dict_keys(['vision_features', 'vision_pos_enc', 'backbone_fpn'])
    #     _, image_embeddings, _, _ = self.visual_model._prepare_backbone_features(backbone_out)
    #     image_embeddings = [_.to(images.dtype) for _ in image_embeddings]
    #     batch_size = images.shape[0]
    #     if self.visual_model.directly_add_no_mem_embed:
    #         image_embeddings[-1] = image_embeddings[-1] + self.visual_model.no_mem_embed

    #     feats = [
    #         feat.permute(1, 2, 0).view(batch_size, -1, *feat_size)
    #         for feat, feat_size in zip(image_embeddings[::-1], self._bb_feat_sizes[::-1])
    #     ][::-1]
    #     _features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]}


    #     assert batch_size == len(offset) - 1

    #     images_evf_list = []
    #     for i in range(len(offset) - 1):
    #         start_i, end_i = offset[i], offset[i + 1]
    #         images_evf_i = (
    #             images_evf[i]
    #             .unsqueeze(0)
    #             .expand(end_i - start_i, -1, -1, -1)
    #             .contiguous()
    #         )
    #         images_evf_list.append(images_evf_i)
    #     images_evf = torch.cat(images_evf_list, dim=0)

    #     multimask_output = False
    #     output = self.mm_extractor.beit3(
    #         visual_tokens=images_evf,
    #         textual_tokens=input_ids,
    #         text_padding_position=~attention_masks
    #         )

    #     feat = output["encoder_out"][:, :1, ...]

    #     feat = self.text_hidden_fcs[0](feat)
    #     feat = torch.split(feat, [offset[i+1] - offset[i] for i in range(len(offset)-1)])

    #     pred_masks = []

    #     for i in range(len(feat)):
    #         (
    #             sparse_embeddings,
    #             dense_embeddings,
    #         ) = self.visual_model.sam_prompt_encoder(
    #             points=None,
    #             boxes=None,
    #             masks=None,
    #             text_embeds=feat[i],
    #         )
    #         sparse_embeddings = sparse_embeddings.to(feat[i].dtype)
    #         high_res_features = [
    #             feat_level[i].unsqueeze(0)
    #             for feat_level in _features["high_res_feats"]
    #         ]
    #         low_res_masks, iou_predictions, _, _ = self.visual_model.sam_mask_decoder(
    #             image_embeddings=_features["image_embed"][i].unsqueeze(0),
    #             image_pe=self.visual_model.sam_prompt_encoder.get_dense_pe(),
    #             sparse_prompt_embeddings=sparse_embeddings,
    #             dense_prompt_embeddings=dense_embeddings,
    #             multimask_output=multimask_output,
    #             repeat_image = True,
    #             high_res_features=high_res_features,
    #         )

    #         if multimask_output:
    #             sorted_ids = torch.argsort(iou_predictions, dim=-1, descending=True)
    #             low_res_masks = torch.take_along_dim(low_res_masks, sorted_ids[..., None, None], dim=1)[:, :1]

    #         pred_mask = self.postprocess_masks(
    #             low_res_masks,
    #             orig_hw=label_list[i].shape,
    #         )
    #         pred_masks.append(pred_mask[:, 0])

    #     gt_masks = masks_list

    #     if inference:
    #         return {
    #             "pred_masks": pred_masks,
    #             "gt_masks": gt_masks,
    #         }

    #     mask_bce_loss = 0
    #     mask_dice_loss = 0
    #     num_masks = 0
    #     for batch_idx in range(len(pred_masks)):
    #         gt_mask = gt_masks[batch_idx]
    #         pred_mask = pred_masks[batch_idx]

    #         assert (
    #             gt_mask.shape[0] == pred_mask.shape[0]
    #         ), "gt_mask.shape: {}, pred_mask.shape: {}".format(
    #             gt_mask.shape, pred_mask.shape
    #         )
    #         mask_bce_loss += (
    #             sigmoid_ce_loss(pred_mask, gt_mask, num_masks=gt_mask.shape[0])
    #             * gt_mask.shape[0]
    #         )
    #         mask_dice_loss += (
    #             dice_loss(pred_mask, gt_mask, num_masks=gt_mask.shape[0])
    #             * gt_mask.shape[0]
    #         )
    #         num_masks += gt_mask.shape[0]

    #     mask_bce_loss = self.bce_loss_weight * mask_bce_loss / (num_masks + 1e-8)
    #     mask_dice_loss = self.dice_loss_weight * mask_dice_loss / (num_masks + 1e-8)
    #     mask_loss = mask_bce_loss + mask_dice_loss

    #     loss = mask_loss

    #     return {
    #         "loss": loss,
    #         "mask_bce_loss": mask_bce_loss,
    #         "mask_dice_loss": mask_dice_loss,
    #         "mask_loss": mask_loss,
    #     }

    def inference(
            self,
            video_path,
            images_evf,
            input_ids,
            # original_size_list,
            multimask_output=False,
        ):
        """
        Segment one text-referred object throughout a video.

        Args:
            video_path: Forwarded to the predictor's
                ``init_state(video_path=...)``.
            images_evf: Passed to the BEiT-3 encoder as ``visual_tokens``
                (presumably an image batch preprocessed for BEiT-3 — confirm
                with the caller).
            input_ids: Passed to the BEiT-3 encoder as ``textual_tokens``. An
                all-zero tensor of the same shape is supplied as
                ``text_padding_position``, i.e. presumably no position is
                treated as padding.
            multimask_output: Accepted for backward compatibility; it is not
                used by this method.

        Returns:
            Dict mapping each frame index yielded by ``propagate_in_video`` to
            a dict ``{object id: numpy array}``, where the array is the
            thresholded mask ``logits > 0.0`` moved to CPU.
        """
        predictor = self.visual_model
        inference_state = predictor.init_state(video_path=video_path)
        predictor.reset_state(inference_state)

        output = self.mm_extractor.beit3(
            visual_tokens=images_evf,
            textual_tokens=input_ids,
            text_padding_position=torch.zeros_like(input_ids),
        )

        # Keep only the first token of the encoder output as the prompt
        # embedding, then project it with the text head.
        feat = output["encoder_out"][:, :1, ...]
        feat = self.text_hidden_fcs[0](feat)

        ann_frame_idx = 0  # the frame index we interact with
        ann_obj_id = 1  # give a unique id to each object we interact with (it can be any integers)

        # The return value is deliberately ignored: per-frame results are
        # collected from propagate_in_video below.
        predictor.add_new_text(
            inference_state=inference_state,
            frame_idx=ann_frame_idx,
            obj_id=ann_obj_id,
            text=feat
        )

        # run propagation throughout the video and collect the results in a dict
        video_segments = {}  # video_segments contains the per-frame segmentation results
        for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
            video_segments[out_frame_idx] = {
                out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
                for i, out_obj_id in enumerate(out_obj_ids)
            }

        return video_segments
  

# Import-time side effect: register the custom config and model with the
# Hugging Face Auto* factories.
# Associates the model-type string "evf" with EvfConfig.
AutoConfig.register("evf", EvfConfig)
# Associates EvfConfig with EvfSam2Model for AutoModelForCausalLM.
# NOTE(review): the model produces segmentation masks rather than text; the
# causal-LM Auto class is presumably used only as a loading hook — confirm.
AutoModelForCausalLM.register(EvfConfig, EvfSam2Model)